[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen) #2596

Merged
sudhakarsingh27 merged 32 commits into NVIDIA:main from sudhakarsingh27:flash_attn_pad_bw_seqs
May 23, 2026

Conversation

@sudhakarsingh27
Member

@sudhakarsingh27 sudhakarsingh27 commented Jan 14, 2026

Description

TLDR

Enable pad_between_seqs=True for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously pad_between_seqs was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With pad_between_seqs=True, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via cu_seqlens_q_padded, but FlashAttention (both FA2 and FA3) had pad_between_seqs hardcoded to False in the CP path, and FA2 was entirely disabled for pad_between_seqs + thd. FA3 can natively handle this via its seqused_q/seqused_k mechanism.

Solution

Use FA3's seqused_q/seqused_k tensors to communicate actual token counts per batch element. Pass cu_seqlens_q_padded for tensor memory layout while deriving seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] from the real cu_seqlens. This applies to both the CP path (A2A and P2P) and the non-CP path.
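The relationship between the two cu_seqlens arrays can be illustrated with a minimal sketch. Plain Python lists stand in for tensors here, the helper name is hypothetical, and the numbers are made up for illustration rather than taken from the PR:

```python
def seqused_from_cu_seqlens(cu_seqlens):
    """Per-sequence token counts, i.e. cu_seqlens[1:] - cu_seqlens[:-1]."""
    return [end - start for start, end in zip(cu_seqlens[:-1], cu_seqlens[1:])]


# Illustrative values: three sequences with 5, 7 and 3 real tokens,
# each padded up to a multiple of 4 inside the packed THD buffer.
cu_seqlens_q = [0, 5, 12, 15]          # cumulative real lengths
cu_seqlens_q_padded = [0, 8, 16, 20]   # cumulative padded lengths (memory layout)

seqused_q = seqused_from_cu_seqlens(cu_seqlens_q)

# The padded offsets say where each sequence starts in memory;
# seqused says how many tokens from that start are real.
valid_ranges = [
    (start, start + used)
    for start, used in zip(cu_seqlens_q_padded[:-1], seqused_q)
]

print(seqused_q)
print(valid_ranges)
```

In this picture the kernel indexes memory with the padded offsets and stops reading each sequence after seqused tokens, which is how padding slots are kept out of the attention computation.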

Fixes #2399

Type of change

  • New feature (non-breaking change which adds functionality)

Changes

Please list the changes introduced in this PR:

context_parallel.py

  • get_fa_args(): Add seqused_q/seqused_k parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded Nones).
  • cp_p2p_fwd_flash_attn() / cp_p2p_bwd_flash_attn(): Accept pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded. When enabled, derive seqused tensors and override cu_seqlens to padded versions (with half-padding for lower-triangle/upper-triangle sections).
  • AttnFuncWithCPAndKVP2P: Thread pad_between_seqs and padded cu_seqlens through all forward/backward cp_p2p_fwd/bwd_flash_attn call sites. Save ctx.pad_between_seqs for backward.
  • AttnFuncWithCPAndQKVOA2A.forward(): Add pad_between_seqs parameter. When enabled with FA3+THD, derive seqused and swap cu_seqlens for padded versions before calling get_fa_args().
  • AttnFuncWithCPAndQKVOA2A.backward(): Same seqused/cu_seqlens override. Use zeros_like (not empty_like) for gradient init when pad_between_seqs since FA3 skips padding positions. Add extra None in return tuple for the new pad_between_seqs gradient slot.
  • attn_forward_func_with_cp(): Pass pad_between_seqs in A2A args list.

backends.py

  • FlashAttention.forward(): Accept cu_seqlens_q_padded/cu_seqlens_kv_padded. Detect pad_between_seqs by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass seqused_q/seqused_k.
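As a rough sketch of the detection step described for backends.py, the check below compares the padded and actual offsets. The helper name is hypothetical and lists stand in for tensors; the actual comparison in TE may differ in detail:

```python
def detect_pad_between_seqs(cu_seqlens, cu_seqlens_padded):
    """Hypothetical helper: padding exists iff padded offsets differ from real ones."""
    if cu_seqlens_padded is None:
        # No padded layout was supplied, so there is nothing to compare against.
        return False
    return list(cu_seqlens_padded) != list(cu_seqlens)


print(detect_pad_between_seqs([0, 5, 12], [0, 8, 16]))
print(detect_pad_between_seqs([0, 8, 16], [0, 8, 16]))
```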

dot_product_attention.py

  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded through to FlashAttention.

utils.py

  • Only disable FA2 (not FA3) when pad_between_seqs + thd. FA3 handles this natively via seqused.

test_attention_with_cp.py

  • Add @pytest.mark.parametrize("pad_between_seqs", [False, True]) to flash attention CP tests.
  • Skip pad_between_seqs=True for non-THD formats, when FA3 is not installed, and for a2a+p2p comm type (not yet supported).

run_attention_with_cp.py

  • Thread pad_between_seqs through generate_input_shapes() and run_dpa_with_cp().
  • When pad_between_seqs, set cu_seqlens_q to actual lengths (not just for FusedAttention).
  • Handle FA3 backward NaN at padding positions: nan_to_num(nan=0.0).
  • Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros at padding slots).
  • Add tensor names to NaN/Inf assertion messages for debuggability.
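The NaN handling and padding cleanup listed above could look roughly like the following. This is a pure-Python sketch under stated assumptions: a flat list stands in for a THD tensor, clean_padding is a hypothetical name, and the NaN replacement mimics nan_to_num(nan=0.0):

```python
import math


def clean_padding(values, cu_seqlens, cu_seqlens_padded):
    """Replace NaN with 0.0, then zero every padding slot between sequences."""
    out = [0.0 if math.isnan(v) else v for v in values]
    for i in range(len(cu_seqlens) - 1):
        used = cu_seqlens[i + 1] - cu_seqlens[i]   # real tokens in sequence i
        pad_start = cu_seqlens_padded[i] + used    # first padding slot
        pad_end = cu_seqlens_padded[i + 1]         # start of the next sequence
        for t in range(pad_start, pad_end):
            out[t] = 0.0
    return out


# Two sequences with 2 and 3 real tokens, each padded to 4 slots.
values = [1.0, 2.0, float("nan"), 7.5, 3.0, 4.0, 4.0, -9.0]
cleaned = clean_padding(values, [0, 2, 5], [0, 4, 8])
print(cleaned)
```

Running the cleanup on both the reference and the test tensor before any NaN/Inf assertion or comparison means only real token positions influence the result.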

test_attention.py

  • Group FlashAttention with FusedAttention for padded input/output handling in _run_dot_product_attention() (previously FlashAttention used original unpadded inputs).
  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded and pad_between_seqs to DPA call for FlashAttention backend.
  • Add pad_between_seqs=True to parametrize with skip for non-THD formats.

New Tests

CP tests (test_attention_with_cp.py)

Added @pytest.mark.parametrize("pad_between_seqs", [False, True]) to test_cp_with_flash_attention. Skip conditions: non-THD formats, FA3 not installed, a2a+p2p comm type.

5 new tests that run (all pad_between_seqs=True, thd, bf16):

| Test | CP comm | Model config |
|---|---|---|
| True-p2p-thd-cp_1_0-bf16 | P2P | causal, 1 head |
| True-p2p-thd-cp_2_1-bf16 | P2P | causal, 2 heads |
| True-a2a-thd-cp_1_0-bf16 | A2A | causal, 1 head |
| True-a2a-thd-cp_1_2-bf16 | A2A | causal, sliding window |
| True-a2a-thd-cp_2_1-bf16 | A2A | causal, 2 heads |

Non-CP tests (test_attention.py)

Added True to @pytest.mark.parametrize("pad_between_seqs", [False, True]) on test_dot_product_attention, with skip for non-THD. Also changed _run_dot_product_attention so FlashAttention uses padded inputs/cu_seqlens and receives pad_between_seqs=True.

48 new test IDs collected, but all are skipped because the main parametrize uses qkv_layout=None (defaults to sbhd, not thd). The non-CP pad_between_seqs + FA3 code path is exercised indirectly when other test functions call test_dot_product_attention with qkv_layout="thd_thd_thd" (e.g., test_dpa_softmax_thd).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27 sudhakarsingh27 self-assigned this Jan 14, 2026
@greptile-apps
Contributor

greptile-apps Bot commented Jan 14, 2026

Greptile Summary

This PR extends pad_between_seqs=True support to FlashAttention 3 with THD (variable-length) format, covering both the non-CP path and CP paths (P2P and A2A). The mechanism uses FA3's seqused_q/seqused_k parameters to communicate actual per-batch token counts while passing cu_seqlens_q_padded as the tensor memory-layout descriptor.

  • Core logic (context_parallel.py, backends.py): get_fa_args() now threads seqused_q/seqused_k through to FA3 forward and backward; P2P and A2A CP paths derive seqused from per-step unpadded cu_seqlens and override the layout descriptor to the padded version; backward gradient buffers are initialized with zeros_like instead of empty_like so FA3 tile-spillover at padding positions leaves zeros rather than garbage.
  • Backend selection (utils.py): FA2 and FA4 are correctly disabled for pad_between_seqs+THD; FA3 is now the only non-fused backend allowed for this combination.
  • Tests (test_attention_with_cp.py, run_attention_with_cp.py, test_attention.py): new pad_between_seqs parametrize added with appropriate skip guards; however, the test driver passes the bool directly while run_dpa_with_cp compares against the string "True", so all five new CP test cases silently run as pad_between_seqs=False. Additionally, the NaN/Inf assertions in run_attention_with_cp.py execute before the padding-position zeroing code, which would cause spurious failures if FA3 writes NaN to padding slots once the bool/string issue is corrected.
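The bool/string mismatch described in the last bullet is easy to reproduce in isolation. The sketch below shows why the comparison silently evaluates to False, along with one possible normalization; parse_flag is a hypothetical helper, not a function from the test suite:

```python
def parse_flag(value):
    """Hypothetical helper: accept a bool or its string form (e.g. from a CLI)."""
    if isinstance(value, bool):
        return value
    return str(value).strip().lower() == "true"


pad_between_seqs = True

# A bool never compares equal to a str, so this is False even though the flag is set.
naive = pad_between_seqs == "True"

# Normalizing first gives the intended answer for both bools and strings.
robust = parse_flag(pad_between_seqs)

print(naive, robust, parse_flag("True"), parse_flag("False"))
```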

Confidence Score: 3/5

The production attention kernels are wired correctly for the supported P2P and A2A paths, but the new tests do not actually exercise the new code path due to a bool/string type mismatch, and a secondary ordering problem in the test runner would cause those tests to fail even after the type mismatch is fixed.

Multiple issues compound in the test layer: every pad_between_seqs=True CP test silently falls back to pad_between_seqs=False because run_dpa_with_cp compares the bool True against the string "True". Fixing that comparison would then expose the NaN-check-before-zeroing ordering problem, where FA3 tile-spillover NaN values in reference tensors trigger the assertion at line 562 before the zeroing code at line 277 can run. The net effect is that the feature ships without any working test coverage of the new code path. Additionally, all_gather CP mode silently ignores pad_between_seqs, and a2a+p2p lacks a production-time guard.

tests/pytorch/attention/test_attention_with_cp.py and tests/pytorch/attention/run_attention_with_cp.py need attention for the bool/string type mismatch and NaN assertion ordering; transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py needs a runtime guard or error for the all_gather+pad_between_seqs and a2a+p2p+pad_between_seqs combinations.

Important Files Changed

| Filename | Overview |
|---|---|
| transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py | Core CP logic for pad_between_seqs: adds seqused_q/k to FA3 P2P and A2A paths, pre-zeros backward gradient buffers, but the all_gather branch silently drops pad_between_seqs and no production guard exists for a2a+p2p. |
| transformer_engine/pytorch/attention/dot_product_attention/backends.py | Adds seqused_q/k to non-CP FA3 path and correctly swaps cu_seqlens→cu_seqlens_padded as the memory layout descriptor; missing qkv_format=="thd" guard on the seqused block could affect non-THD callers that set pad_between_seqs=True. |
| transformer_engine/pytorch/attention/dot_product_attention/utils.py | Correctly disables FA2 and FA4 for pad_between_seqs+THD; also adds non-paged KV-cache FA2 disable for max_seqlen_kv%256≠0 and narrows the deterministic FA3 head_dim guard to training-only. |
| tests/pytorch/attention/run_attention_with_cp.py | Test runner for CP attention: adds fa_pad_between_seqs parameter and padding-zeroing logic, but NaN/Inf assertions run before the zeroing code (will abort on FA3 NaN tile-spillover), and FusedAttention padding zero-checks are silently removed. |
| tests/pytorch/attention/test_attention_with_cp.py | Adds pad_between_seqs parametrize to FA CP tests with correct skip guards; passes bool pad_between_seqs to run_dpa_with_cp as fa_pad_between_seqs, but run_dpa_with_cp compares with "True" (string), so all five new pad_between_seqs=True tests silently execute as pad_between_seqs=False. |
| tests/pytorch/attention/test_attention.py | Adds pad_between_seqs=True to test_dot_product_attention parametrize (all new cases skip since qkv_layout=None); correctly moves FlashAttention to use padded inputs/cu_seqlens matching FusedAttention treatment. |
| qa/L1_pytorch_distributed_unittest/test.sh | Moves CP tests earlier and parallelizes them across GPU sets when ≥8 GPUs are available; clean shell scripting with proper error propagation. |
| qa/L3_pytorch_FA_versions_test/test.sh | Adds error-accumulation pattern replacing set -e, supports per-FA-version XML output, and runs CP tests only with the designated FA version per architecture. |

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["DotProductAttention.forward()"] --> B{cp_group?}
    B -- No --> C["FlashAttention.forward()"]
    B -- Yes --> D["attn_forward_func_with_cp()"]

    C --> E{pad_between_seqs + FA3?}
    E -- Yes --> F["Append cu_seqlens_q_padded as layout descriptor + set seqused_q/k kwargs"]
    E -- No --> G["Append cu_seqlens_q as layout descriptor"]

    D --> H{cp_comm_type}
    H -- p2p / a2a+p2p --> I["AttnFuncWithCPAndKVP2P (pad_between_seqs ✓)"]
    H -- a2a --> J["AttnFuncWithCPAndQKVOA2A (pad_between_seqs ✓)"]
    H -- all_gather --> K["AttnFuncWithCPAndKVAllGather (pad_between_seqs ✗ not forwarded)"]

    I --> L["cp_p2p_fwd_flash_attn(): seqused=diff(cu_seqlens_per_step), cu_seqlens→padded"]
    J --> M["get_fa_args(seqused_q, seqused_k): cu_seqlens→cu_seqlens_padded"]

    style K fill:#ffcccc

Reviews (50): Last reviewed commit: "Merge branch 'main' into flash_attn_pad_..."

Contributor

@greptile-apps greptile-apps Bot left a comment

4 files reviewed, 1 comment

Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from ea51821 to e338049 Compare March 10, 2026 23:37
@sudhakarsingh27
Member Author

/te-ci pytorch L2

@sudhakarsingh27 sudhakarsingh27 changed the title from "Flash attn pad bw seqs" to "[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD" Mar 11, 2026
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Outdated
Comment thread transformer_engine/pytorch/attention/dot_product_attention/backends.py Outdated
Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
@sudhakarsingh27
Member Author

/te-ci pytorch L1

Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
@sudhakarsingh27
Member Author

/te-ci pytorch L3

@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from b0a3c64 to 057f406 Compare April 9, 2026 05:18
@sudhakarsingh27
Member Author

/te-ci pytorch L3

1 similar comment
@sudhakarsingh27
Member Author

/te-ci pytorch L3

@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from 00bdc92 to 0f48ebc Compare April 10, 2026 15:04
Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
Comment thread tests/pytorch/attention/test_attention.py Outdated
if not FlashAttentionUtils.v3_is_installed:
    pytest.skip("pad_between_seqs with CP requires Flash Attention v3!")
if cp_comm_type == "a2a+p2p":
    pytest.skip("pad_between_seqs is not yet supported with A2A+P2P CP comm type!")
Collaborator

What about AG?

if pad_between_seqs:
    dq, dk, dv = [torch.zeros_like(x) for x in [q_part, k_part, v_part]]
else:
    dq, dk, dv = [torch.empty_like(x) for x in [q_part, k_part, v_part]]
Collaborator

Just to confirm, we can't do this for fwd, right? Because fwd output is not allocated by us.

Member Author

It's a limitation in the Flash Attention code: forward does not declare out as a mutated argument (so pre-zeroing does not survive), while backward treats dq/dk/dv as in-place mutable (so pre-zeroing sticks). This zeroing also works only in the CP code, where we can provide those args.

None of the zeroing works for the non-CP path because we only have the forward call in TE.

FA3 / Hopper (hopper/flash_attn_interface.py)
- Forward: mutates_args=(), namespace flash_attn_3::_flash_attn_forward
- Backward: mutates_args=("dq", "dk", "dv"), namespace flash_attn_3::_flash_attn_backward
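To make the zeros_like versus empty_like point concrete, here is a toy model of a backward kernel that writes gradients in place and skips padding positions. It is a pure-Python sketch, not FA3 code: fake_backward is a hypothetical stand-in, and the stale value -7.0 plays the role of whatever happens to be in uninitialized memory:

```python
def fake_backward(dq, cu_seqlens, cu_seqlens_padded):
    """Toy stand-in for a kernel that writes gradients in place, real tokens only."""
    for i in range(len(cu_seqlens) - 1):
        used = cu_seqlens[i + 1] - cu_seqlens[i]
        start = cu_seqlens_padded[i]
        for t in range(start, start + used):
            dq[t] = 1.0  # pretend gradient
    return dq


cu, cu_padded = [0, 2, 5], [0, 4, 8]

stale = [-7.0] * 8   # plays the role of empty_like: arbitrary prior contents
zeros = [0.0] * 8    # plays the role of zeros_like

from_stale = fake_backward(stale, cu, cu_padded)
from_zeros = fake_backward(zeros, cu, cu_padded)

print(from_stale)
print(from_zeros)
```

Because the kernel never touches the padding slots, their final value is whatever the caller put there beforehand, which is why the caller-allocated gradient buffers are the place where zeroing can be applied.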

@ptrendx ptrendx added this to the 2.15 milestone Apr 23, 2026
Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Comment thread tests/pytorch/attention/run_attention_with_cp.py Outdated
Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
@sudhakarsingh27
Member Author

/te-ci pytorch L3

Add support for padding between sequences (pad_between_seqs) in the
FlashAttention 3 backend when used with context parallelism (CP).

Key changes:
- backends.py: Pass fa_pad_between_seqs through to FA3 forward/backward
- context_parallel.py: Handle pad_between_seqs in A2A and P2P CP paths,
  zero FA3 padding garbage in CP forward, fix a2a backward alignment
- dot_product_attention.py: Auto-detect pad_between_seqs from cu_seqlens
- utils.py: Gate FA3 deterministic backward for hdim>=256, fix
  flash_attn_supported override for cross-attention and large head_dim,
  disable UnfusedDotProductAttention for pad_between_seqs, add SM100+
  FA3 skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add test parametrization for pad_between_seqs in flash attention tests.
Update run_attention_with_cp.py to support the new parameter and fix
batch boundary alignment in the non-CP FA3 path. Run tests in parallel
when multiple GPUs are available.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Add deterministic CP test runs to L3 FA versions test. Support TE_PATH
positional arg and fix GPU threshold for parallel test execution.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…raint

The previous check disabled FA3 for deterministic mode whenever
head_dim_qk > 128, which was overly conservative — FA3 forward supports
deterministic execution at any head dim. The actual constraint from
flash_api.cpp is that the backward pass does not support deterministic
mode when max(head_size, head_size_v) >= 256.

Narrow the gate to only disable FA3 during training (backward) and
raise the threshold to >= 256, checking both head_dim_qk and head_dim_v
to handle MLA configs with asymmetric head dimensions.

Ref: https://github.com/Dao-AILab/flash-attention/blob/ac6f2eb5/hopper/flash_api.cpp#L1370

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
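The narrowed gate described in this commit message can be sketched as a small predicate. The function name is hypothetical and the real check lives inside TE's backend-selection logic, so treat this as an illustration of the stated rule only:

```python
def fa3_allowed_in_deterministic_mode(is_training, head_dim_qk, head_dim_v):
    """Sketch of the rule: only the backward pass restricts deterministic FA3."""
    if not is_training:
        # Forward-only (inference) has no head-dim restriction in deterministic mode.
        return True
    # Backward does not support deterministic mode when max(head dims) >= 256.
    return max(head_dim_qk, head_dim_v) < 256


print(fa3_allowed_in_deterministic_mode(False, 256, 256))
print(fa3_allowed_in_deterministic_mode(True, 192, 128))
print(fa3_allowed_in_deterministic_mode(True, 192, 256))
```

Checking both head dimensions matters for asymmetric configurations, where head_dim_v can cross the limit even though head_dim_qk does not.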
@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from 9c01601 to 4745f98 Compare April 24, 2026 23:02
The pad_between_seqs gate in get_attention_backend only disabled
FlashAttention 2, letting FA4 leak through to the test-time
fused-vs-flash comparison. On B200 runners that install flash-attn-4,
this caused test_dpa_qkv_layout_thd to compare FusedAttention against
an FA4 output whose padded positions contain garbage, producing 48
numerics failures in L3_pytorch_FA_versions_test--B200_1GPU.

The log message already claimed FA4 would be disabled — this change
makes the code match the message: set use_flash_attention_4 = False
alongside use_flash_attention_2 when pad_between_seqs is True. FA3
continues to support pad_between_seqs via seqused_q/seqused_k.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
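A minimal sketch of the gating this commit describes is shown below. The function and its argument list are hypothetical; only the flag names use_flash_attention_2 and use_flash_attention_4 come from the commit message, and the THD condition is taken from the PR description:

```python
def gate_flash_backends(pad_between_seqs, qkv_format,
                        use_flash_attention_2,
                        use_flash_attention_3,
                        use_flash_attention_4):
    """Sketch: with pad_between_seqs + THD, only FA3 stays eligible."""
    if pad_between_seqs and qkv_format == "thd":
        use_flash_attention_2 = False
        use_flash_attention_4 = False  # previously left enabled by mistake
    return use_flash_attention_2, use_flash_attention_3, use_flash_attention_4


print(gate_flash_backends(True, "thd", True, True, True))
print(gate_flash_backends(False, "thd", True, True, True))
```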
@sudhakarsingh27 sudhakarsingh27 changed the title from "[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD" to "[PyTorch] Add pad_between_seqs support for non-CP and CP (A2A and P2P) with FA3 + THD (varlen)" Apr 24, 2026
sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 15, 2026
Each (world_size) is served by one long-lived torchrun running
run_attention_with_cp_pool.py. Tests submit work over rank-0 stdin
as JSON and read results from rank-0 stdout, replacing the
per-test torchrun launch path. NCCL init/destroy happens once
per pool, not once per case, eliminating ~9s overhead per test
and fixing L3 timeouts.

Why two pool sizes: cp_comm_type="a2a+p2p" needs world_size=4;
everything else uses world_size=2. We can't resize an active PG, so
one pool per world_size, routed by num_gpus. Pools spawn lazily on
first use so a session that only exercises 2-GPU cases never pays
the 4-GPU init cost.

Includes:
- PoolWorker class with sentinel-prefixed JSON protocol over rank-0
  stdio (sentinel filters out torchrun status / library prints that
  share the stdout fd)
- Stderr ring buffer (200 lines / ~4 KB tail) attached to crash-path
  AssertionErrors so CI JUnit XML shows the real failure cause
- POOL_SUBMIT_TIMEOUT_SEC defaulting to 90 s (~6x p50 case wall on
  H100); override via NVTE_CP_POOL_TIMEOUT_SEC
- Stream race fix on max_logit_per_step in all-gather CP forward:
  wait_stream(flash_attn_streams[i-1]) before torch.maximum, so the
  read on the default stream doesn't race with the write on cp_stream
  in iteration i=2. The pool's persistent process exposed this latent
  race; per-process subprocess design happened to schedule it safely.
- Deep-copy of model_configs_flash_attn[model] to prevent in-place
  attn_mask_type mutation from leaking across pool cases
- Deterministic-mode skips for FusedAttention configs that OOM on
  sm90 under NVTE_ALLOW_NONDETERMINISTIC_ALGO=0

Preserves PR NVIDIA#2596 pad_between_seqs additions (fa_pad_between_seqs
parameter through generate_input_shapes and run_dpa_with_cp, THD
padding cleanup for FA3 tile-spillover comparison).

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
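The sentinel-prefixed JSON protocol mentioned in this commit message can be illustrated with a short sketch. The sentinel string, the message fields, and the helper names are all assumptions made for the example; the commit text does not specify them:

```python
import json

SENTINEL = "@@POOL@@ "  # hypothetical marker; the real prefix is not given above


def encode(message):
    """Serialize one message as a single sentinel-prefixed line."""
    return SENTINEL + json.dumps(message)


def decode_stream(lines):
    """Keep only sentinel-prefixed lines and parse their JSON payloads."""
    results = []
    for line in lines:
        line = line.rstrip("\n")
        if line.startswith(SENTINEL):
            results.append(json.loads(line[len(SENTINEL):]))
    return results


# Rank-0 stdout mixes protocol lines with unrelated prints on the same fd.
stdout_lines = [
    "W0515 torchrun status line\n",
    encode({"case": "cp_1_0", "status": "passed"}) + "\n",
    "some library print\n",
    encode({"case": "cp_2_1", "status": "failed"}) + "\n",
]
parsed = decode_stream(stdout_lines)
print(parsed)
```

Prefixing each result line lets the reader ignore anything else written to the shared stdout without needing a separate pipe per worker.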
@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch 2 times, most recently from bcb717e to bce64ff Compare May 15, 2026 21:38
sudhakarsingh27 added a commit to sudhakarsingh27/TransformerEngine that referenced this pull request May 15, 2026
Re-applying the formatting fixes that pre-commit.ci posted on
PR NVIDIA#2596 after the previous push (commit bcb717e, overwritten
by the cleanup rebase).

for more information, see https://pre-commit.ci

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Member Author

/te-ci pytorch L3

2 similar comments
@sudhakarsingh27
Member Author

/te-ci pytorch L3

@sudhakarsingh27
Member Author

/te-ci pytorch L3

…to flash_attn_pad_bw_seqs

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

# Conflicts:
#	tests/pytorch/attention/test_attention_with_cp.py
@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from 287403f to d8e8ba4 Compare May 22, 2026 01:02
@sudhakarsingh27
Member Author

/te-ci pytorch L3

Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Address two pending review comments:
1. The "auto-set when RUN_L3_TESTS=1" annotation on the base-image FA3
   preinstall is no longer accurate; drop it so readers don't grep for a
   coupling that doesn't exist.
2. `flash_attn_interface` reads like a generic FA API even though the
   top-level shim is only created by the FA3 install. Switching to
   `import flash_attn_3` makes the FA3-specific intent unambiguous and
   matches the FA3 package layout produced by the source build.

Local validation on H100 (sm90) with FA3 active, TE worktree resolving
to the editable install (verified via three-layer import check from
/tmp): test_attention_with_cp.py parallel det+nondet — 45 passed / 0
failed nondet (3:52), 33 passed / 0 failed det (2:55). 33 pad-True
nondet passes + 21 pad-True det passes confirm the FA3+THD+CP path is
exercised; 5 det OOM cases skip cleanly via the existing inline guard.

Same test scope is exercised by L1_pytorch_distributed_unittest
(parallel det+nondet) and the FA3 iteration of L3_pytorch_FA_versions_test;
the changes here are L3-only documentation/detection tweaks and do not
alter the Python test code, but the L1+L3 CP execution was re-run on
the cleaned PR head end-to-end as proof.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
KshitijLakhani previously approved these changes May 22, 2026
Comment on lines +426 to +431
if qkv_format == "thd" and config.num_heads >= 20 and get_device_compute_capability() == (9, 0):
    pytest.skip(
        "Deterministic FusedAttention backward with THD format OOMs on sm90"
        " for this particular test config since cuDNN reserves memory"
        " proportional to bHSS (known cuDNN issue)."
    )
Collaborator

The motivation for this makes sense to me, but the skip condition looks at the problem through a narrow lens: the real issue is total memory (bHSS), yet we guard on num_heads only.

This skip guard would not be correct if tomorrow someone were to add a test with small b and S but H >= 20 (IIUC). It makes it seem that the issue is num_heads rather than total memory.

Is there a better way to do this?

Member Author

Good catch — gated on the actual b*H*S*S product instead of num_heads in d3bd4e4. Threshold of 1e9 empirically matches the existing 5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3 — bHSS 1.07B–4.29B) and lets the smaller configs (cp_1_0/cp_2_1/cp_2_4/cp_3_2/cp_3_4, all ~0.40B) keep running. Local det+nondet still 33/0 + 45/0 with 5 OOM skips fired by the new gate.

Comment thread qa/L3_pytorch_FA_versions_test/test.sh Outdated
Comment thread qa/L1_pytorch_distributed_unittest/test.sh
1. Det FusedAttention backward THD/sm90 OOM skip: gate on the actual
   memory pressure (b*H*S*S) instead of num_heads >= 20. The cuDNN
   workspace is proportional to bHSS, so a future config with H >= 20
   but small b or S would be needlessly skipped under the old guard,
   while a config with H < 20 but large b*S that hit the same OOM
   wouldn't be caught. Threshold 1e9 empirically matches the existing
   5-case skip set on the test_essential fused subset (cp_2_0, cp_2_2,
   cp_3_1, cp_4_2, cp_4_3 — bHSS in 1.07B–4.29B) and lets cp_1_0/
   cp_2_1/cp_2_4/cp_3_2/cp_3_4 (bHSS ~0.40B) keep running.

2. L3 FA3 install comment: drop the "Dockerfile.base INSTALL_FA3=1"
   reference. The detection check is the contract; mentioning a
   specific image variable couples this script to an out-of-tree
   provisioning detail that may evolve independently.

Local validation on H100 (sm90) with FA3 active and TE worktree
resolving to editable (verified via /tmp-cwd three-layer import check
after reinstall — the /usr/local TE shadow had reappeared between
sessions): test_attention_with_cp.py parallel det+nondet — 45 passed /
0 failed nondet (4:09), 33 passed / 0 failed det (3:14). 33 pad-True
nondet passes + 21 pad-True det passes; 5 det OOM cases skip via the
new bHSS gate — same cases as the old num_heads-only gate.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@sudhakarsingh27
Member Author

/te-ci pytorch L3

@KshitijLakhani KshitijLakhani self-requested a review May 22, 2026 21:23
KshitijLakhani previously approved these changes May 22, 2026
Comment thread tests/pytorch/attention/test_attention_with_cp.py Outdated
…ation

Address review nits on the deterministic THD-backward OOM guard:
1. Replace the magic number 1_000_000_000 with the named constant
   SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30, so the value is searchable
   and labeled.
2. Replace the prefatory comment with a short note tying the number to
   cuDNN's actual workspace request (~128 * bHSS bytes, measured on
   cuDNN 9.21.0 sm90 — see local sweep). At bHSS = 1<<30 the request is
   128 GiB, which doesn't fit on H100's 80 GB.
3. Flag the b>=3 caveat for future readers: cuDNN rounds the batch up
   internally so workspace grows super-linearly past b=2 (b=4 asks for
   4x the b=2 workspace, not 2x). The current fused-essential matrix is
   all b=2, so the threshold stays correct for what the test exercises;
   the note is there so the next person doesn't have to rediscover it.

Skip set is unchanged: cp_2_0, cp_2_2, cp_3_1, cp_4_2, cp_4_3.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
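The bHSS-based guard can be sketched as follows. The constant name and value come from the commit message above; the helper names, the inclusive comparison, and the example shapes are assumptions for illustration and are not the actual test configs:

```python
SM90_DET_FUSED_THD_BWD_MAX_BHSS = 1 << 30  # named constant from the commit message


def bhss(batch, num_heads, seqlen_q, seqlen_kv):
    """Product b * H * S_q * S_kv that the observed workspace scales with."""
    return batch * num_heads * seqlen_q * seqlen_kv


def should_skip(batch, num_heads, seqlen_q, seqlen_kv):
    # Assumption: the comparison is inclusive, so bHSS == 1 << 30 is skipped.
    return bhss(batch, num_heads, seqlen_q, seqlen_kv) >= SM90_DET_FUSED_THD_BWD_MAX_BHSS


def observed_workspace_gib(bhss_value):
    """Observed request of roughly 128 bytes per bHSS element, in GiB."""
    return 128 * bhss_value / (1 << 30)


# Hypothetical shapes chosen only to land on either side of the threshold.
small = (2, 12, 4096, 4096)
large = (2, 32, 4096, 4096)

print(bhss(*small), should_skip(*small))
print(bhss(*large), should_skip(*large))
print(observed_workspace_gib(SM90_DET_FUSED_THD_BWD_MAX_BHSS))
```

Gating on the product rather than on num_heads alone means a config with many heads but short sequences is not skipped needlessly, which is the concern raised in the review thread.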
We measured the workspace request from outside cuDNN, so the comment should
say "observed" rather than asserting what cuDNN does. Reframes the ~128 *
bHSS bytes formula and the super-linear b>=3 behavior as empirical
observations from our sweep.

No code change.

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@KshitijLakhani KshitijLakhani self-requested a review May 23, 2026 01:17
Collaborator

@KshitijLakhani KshitijLakhani left a comment

LGTM

@sudhakarsingh27 sudhakarsingh27 merged commit 80ea313 into NVIDIA:main May 23, 2026
12 of 13 checks passed

Development

Successfully merging this pull request may close these issues.

Support FlashAttention with pad_between_seqs=True

4 participants